package cn.zfs.springcloud.datasource.aspect;

import cn.zfs.springcloud.datasource.annotation.SwitchDataSource;
import cn.zfs.springcloud.datasource.contextholder.DataSourceContextHolder;
import cn.zfs.springcloud.datasource.contextholder.TransactionContextHolder;
import cn.zfs.springcloud.datasource.contextholder.TransactionNumContextHolder;
import cn.zfs.springcloud.datasource.init.InitData;
import cn.zfs.springcloud.datasource.model.TransactionModel;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.stereotype.Component;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;

import java.util.List;

/**
 * Shared helpers for the data-source aspects: switching the current thread's data source and
 * opening, committing and rolling back programmatic transactions whose state is kept in the
 * thread-local holders {@link TransactionContextHolder} and {@link TransactionNumContextHolder}.
 */
@SuppressWarnings({"JavaDoc", "WeakerAccess", "SpringAutowiredFieldsWarningInspection"})
@Component
public class AspectPublic {

    private static final Logger logger = LoggerFactory.getLogger(AspectPublic.class);

    @Autowired
    private InitData initData;

    /**
     * Switches the data source used by the current thread.
     *
     * @param dataSourceName        requested data source name; a blank or {@code null} value falls
     *                              back to {@code defaultDataSourceName}
     * @param defaultDataSourceName name used when {@code dataSourceName} is blank
     */
    public void switchDataSource(String dataSourceName, String defaultDataSourceName) {
        // Same blank check as openTransaction, so a null name no longer becomes the lookup key.
        String target = StringUtils.isBlank(dataSourceName) ? defaultDataSourceName : dataSourceName;
        logger.debug("切换数据源:{}", target);
        DataSourceContextHolder.setDetermineCurrentLookupKey(target);
    }

    /**
     * Optionally starts a new transaction, then records the transaction count for this call.
     *
     * <p>NOTE(review): {@code dataSourceName} is only stored in the {@link TransactionModel}; the
     * connection is obtained through the single {@code "transactionManager"} bean, presumably
     * routed by the key set via {@link #switchDataSource} beforehand — confirm callers switch first.
     *
     * @param applicationContext       context from which the {@code "transactionManager"} bean is fetched
     * @param dataSourceName           data source name recorded with the transaction; blank falls back
     *                                 to {@link InitData#getDefaultDataSourceName()}
     * @param openTransaction          whether a transaction is started by this call
     * @param openTransactionNum       number of transactions to open, passed to
     *                                 {@link TransactionNumContextHolder#setTransactionNum}
     * @param thisMethodOpenTrancation flag passed through to
     *                                 {@link TransactionNumContextHolder#setTransactionNum}
     */
    public void openTransaction(ApplicationContext applicationContext,
                                String dataSourceName, boolean openTransaction, int openTransactionNum, boolean thisMethodOpenTrancation) {
        logger.debug("是否开启事务:{}", openTransaction);
        if (openTransaction) {
            if (StringUtils.isBlank(dataSourceName)) {
                dataSourceName = initData.getDefaultDataSourceName();
            }
            logger.debug("开启事务:{}", dataSourceName);
            DataSourceTransactionManager transactionManager =
                    applicationContext.getBean("transactionManager", DataSourceTransactionManager.class);
            DefaultTransactionDefinition def = new DefaultTransactionDefinition();
            // Propagation (not isolation): always start a new transaction, suspending any current one.
            def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);
            TransactionStatus status = transactionManager.getTransaction(def);
            // Remember manager + status in the thread-local so commit/rollback can complete it later.
            TransactionContextHolder.setTransactionModel(new TransactionModel(dataSourceName, transactionManager, status));
            logger.debug("为数据源:{},开启事务", dataSourceName);
        }
        logger.debug("设置需要开启事务的个数:{}", openTransactionNum);
        TransactionNumContextHolder.setTransactionNum(openTransactionNum, thisMethodOpenTrancation);
    }

    /**
     * Decrements the outstanding transaction count. Once it reaches zero, commits every recorded
     * transaction (newest first) and clears the thread-locals.
     *
     * <p>If a commit throws, the thread-locals are intentionally left in place so that
     * {@link #rollbackTransaction()} can still see the recorded transactions.
     */
    public void commitTransactionClearThreadLocal() {
        // One participant finished; per the original intent this reduces the total count by one.
        TransactionNumContextHolder.setTransactionNum(InitData.NEGATIVEONE, false);
        Integer transactionNum = TransactionNumContextHolder.getTransactionNum();
        // Null-safe check: avoids an auto-unboxing NullPointerException when no count is stored.
        if (transactionNum == null || transactionNum.intValue() != InitData.ZERO) {
            return;
        }
        List<TransactionModel> transactionModels = TransactionContextHolder.getTransactionModels();
        if (transactionModels != null) {
            logger.debug("开始提交事务提交事务");
            // REQUIRES_NEW suspends earlier transactions of the same manager, so complete newest first.
            for (int i = transactionModels.size() - InitData.ONE; i >= InitData.ZERO; i--) {
                TransactionModel transactionModel = transactionModels.get(i);
                transactionModel.getDataSourceTransactionManager().commit(transactionModel.getTransactionStatus());
            }
        }
        // Clear even when no transaction was recorded, so the count does not leak into pooled threads.
        clearThreadLocal();
    }

    /**
     * Rolls back every recorded, not-yet-completed transaction, newest first.
     *
     * <p>A failure in one rollback does not stop the others. After the loop, the first failure is
     * rethrown with any later failures attached as suppressed exceptions. This method does not
     * clear the thread-locals; call {@link #clearThreadLocal()} for that.
     */
    public void rollbackTransaction() {
        List<TransactionModel> transactionModels = TransactionContextHolder.getTransactionModels();
        if (transactionModels == null || transactionModels.isEmpty()) {
            return;
        }
        logger.info("事务回滚");
        RuntimeException failure = null;
        for (int i = transactionModels.size() - InitData.ONE; i >= InitData.ZERO; i--) {
            TransactionModel transactionModel = transactionModels.get(i);
            TransactionStatus status = transactionModel.getTransactionStatus();
            // Skip missing statuses and ones already completed (e.g. committed before a later commit
            // failed); rolling back a completed status is rejected by Spring.
            if (status == null || status.isCompleted()) {
                continue;
            }
            try {
                transactionModel.getDataSourceTransactionManager().rollback(status);
            } catch (RuntimeException e) {
                logger.error("事务回滚失败", e);
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Removes both the recorded transactions and the transaction count from the current thread.
     */
    public void clearThreadLocal() {
        logger.debug("清空事务threadlocal");
        TransactionContextHolder.removeTransactionModel();
        logger.debug("清空事务个数threadlocal");
        TransactionNumContextHolder.removeTransactionNum();
    }
}
